"""Plotting functions for Simulation C.

This script generates two figures:

1. **Log‑slope plot:** For each mark probability ``p`` it plots
   ``log(V/V0)`` against ``N`` and overlays the predicted linear
   dependence with slope ``log(1-p)``.  Points are medians across seeds
   with 68 % confidence intervals (from ``bootstrap_ci``) as error bars.
2. **Collapse plot:** It plots the measured ratio ``V/V0`` against
   ``p*N`` for all parameter combinations and seeds, overlaying the
   universal function ``exp(-p*N)``.

Optionally, ablation data can be provided, in which case each ablation
type is overlaid on both figures: dotted median curves on the log‑slope
plot and ``x`` markers on the collapse plot.
"""

from __future__ import annotations

import argparse
import math
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .metrics import bootstrap_ci


def plot_log_slopes(df: pd.DataFrame, out_path: Path, df_ab: Optional[pd.DataFrame] = None) -> None:
    """Plot ``log(V/V0)`` against ``N`` for every mark probability ``p``.

    For each ``p > 0`` found in ``df`` the per-``N`` median of
    ``V_over_V0`` is drawn as a point with an error bar taken from
    ``bootstrap_ci(..., ci=0.68)``, together with the dashed predicted
    line ``log(1 - p) * N``.  The figure is written to ``out_path`` and
    then closed.

    Args:
        df: Summary table with the columns ``p``, ``N`` and ``V_over_V0``.
        out_path: Destination of the saved figure.
        df_ab: Optional ablation table with the columns ``ablation``,
            ``p``, ``N`` and ``V_over_V0``.  Only ``p`` values that also
            occur in ``df`` are drawn, as dotted median curves.
    """
    log_floor = 1e-12  # keeps the logarithm finite when a ratio is zero

    # Exclude p=0 baseline from log plot
    df = df[df["p"] > 0.0]
    p_values = sorted(df["p"].unique())
    fig, ax = plt.subplots(figsize=(7, 5))
    colors = plt.cm.viridis(np.linspace(0, 1, len(p_values)))
    for p, color in zip(p_values, colors):
        sub = df[df["p"] == p]
        # One pass over the N groups (sorted by N) yields both the median
        # and the confidence interval instead of re-filtering per N.
        n_list, med_list, lows, highs = [], [], [], []
        for n_val, ratios in sub.groupby("N")["V_over_V0"]:
            lo, hi = bootstrap_ci(ratios, ci=0.68)
            n_list.append(float(n_val))
            med_list.append(ratios.median())
            lows.append(lo)
            highs.append(hi)
        N_vals = np.asarray(n_list, dtype=float)
        log_y = np.log(np.maximum(med_list, log_floor))
        # Error-bar lengths must be non-negative for Matplotlib; an interval
        # that does not bracket the median would otherwise raise, so clamp.
        err_below = np.maximum(log_y - np.log(np.maximum(lows, log_floor)), 0.0)
        err_above = np.maximum(np.log(np.maximum(highs, log_floor)) - log_y, 0.0)
        # predicted line
        pred_line = np.log(1.0 - p) * N_vals
        ax.errorbar(N_vals, log_y, yerr=[err_below, err_above], fmt="o", color=color)
        ax.plot(N_vals, pred_line, linestyle="--", color=color, label=f"p={p}")
    # Ablations overlay (optional)
    if df_ab is not None and not df_ab.empty:
        for ab_name in df_ab["ablation"].unique():
            df_ab_sub = df_ab[df_ab["ablation"] == ab_name]
            # Same colors per p but dotted lines to differentiate
            for p, color in zip(p_values, colors):
                sub_ab = df_ab_sub[df_ab_sub["p"] == p]
                if sub_ab.empty:
                    # Nothing measured for this (ablation, p) pair: do not
                    # add an empty curve and a misleading legend entry.
                    continue
                med_ab = sub_ab.groupby("N")["V_over_V0"].median().reset_index()
                ax.plot(
                    med_ab["N"].values,
                    np.log(np.maximum(med_ab["V_over_V0"].values, log_floor)),
                    linestyle="dotted",
                    color=color,
                    label=f"{ab_name} p={p}",
                )
    ax.set_xlabel("Number of mark opportunities N")
    ax.set_ylabel("log(V/V0)")
    ax.set_title("Log attenuation vs N with predicted slopes")
    ax.grid(True, ls="--", alpha=0.5)
    # Only draw a legend when something was actually labelled (empty input
    # would otherwise trigger a Matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="best", ncol=2, fontsize="small")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def plot_collapse(df: pd.DataFrame, out_path: Path, df_ab: Optional[pd.DataFrame] = None) -> None:
    """Scatter the measured ``V/V0`` against ``p*N`` over ``exp(-p*N)``.

    Rows with ``p > 0`` from ``df`` are drawn as one scatter series per
    ``p`` value on top of the dashed prediction curve ``exp(-x)``.  The
    figure is written to ``out_path`` and then closed.

    Args:
        df: Summary table with the columns ``p``, ``N`` and ``V_over_V0``.
        out_path: Destination of the saved figure.
        df_ab: Optional ablation table with the columns ``ablation``,
            ``p``, ``N`` and ``V_over_V0``.  Only ``p`` values that also
            occur in ``df`` are drawn, using ``x`` markers.
    """
    # Exclude p=0 baseline from collapse
    df = df[df["p"] > 0.0]
    x = (df["p"].values * df["N"].values).astype(float)
    fig, ax = plt.subplots(figsize=(7, 5))
    # Plot predicted universal curve y=exp(-x).  The range is taken from the
    # finite abscissae only so a stray NaN cannot blank the whole curve, and
    # it falls back to [0, 1] when there is no positive extent to cover.
    finite_x = x[np.isfinite(x)]
    x_max = finite_x.max() * 1.05 if finite_x.size else 0.0
    if x_max <= 0:
        x_max = 1.0
    x_line = np.linspace(0, x_max, 200)
    ax.plot(x_line, np.exp(-x_line), color="black", linestyle="--", label="Prediction exp(-pN)")
    # Plot measured points for each p with different colours
    p_values = sorted(df["p"].unique())
    colors = plt.cm.viridis(np.linspace(0, 1, len(p_values)))
    for p, color in zip(p_values, colors):
        sub = df[df["p"] == p]
        ax.scatter(sub["p"] * sub["N"], sub["V_over_V0"], color=color, label=f"p={p}")
    # Ablations overlay (optional)
    if df_ab is not None and not df_ab.empty:
        for ab_name in df_ab["ablation"].unique():
            df_ab_sub = df_ab[df_ab["ablation"] == ab_name]
            for p, color in zip(p_values, colors):
                sub_ab = df_ab_sub[df_ab_sub["p"] == p]
                if sub_ab.empty:
                    # No rows for this (ablation, p) pair: skip it so the
                    # legend does not list series that have no points.
                    continue
                ax.scatter(
                    sub_ab["p"] * sub_ab["N"],
                    sub_ab["V_over_V0"],
                    marker="x",
                    color=color,
                    label=f"{ab_name} p={p}",
                )
    ax.set_xlabel("p × N")
    ax.set_ylabel("V/V0")
    ax.set_title("Collapse plot: measured attenuation vs pN")
    ax.set_ylim(0, 1.05)
    ax.grid(True, ls="--", alpha=0.5)
    # reduce duplicate labels in legend, keeping the first handle per label
    first_handle = {}
    for handle, label in zip(*ax.get_legend_handles_labels()):
        first_handle.setdefault(label, handle)
    ax.legend(list(first_handle.values()), list(first_handle.keys()), loc="best", fontsize="small", ncol=2)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def main(summary_path: str, output_dir: str, ablation_path: Optional[str] = None) -> None:
    """Read the Simulation C tables and write both figures.

    Args:
        summary_path: CSV file holding the summary table.
        output_dir: Directory that receives the figures; it is created
            (including parents) when missing.
        ablation_path: Optional CSV file holding the ablation table.  When
            falsy, no ablation overlay is passed to the plot functions.
    """
    figures_dir = Path(output_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    summary = pd.read_csv(summary_path)
    if ablation_path:
        ablations = pd.read_csv(ablation_path)
    else:
        ablations = None

    # Each plot function is paired with the file name it writes.
    figure_jobs = (
        (plot_log_slopes, "simC_log_slopes.png"),
        (plot_collapse, "simC_collapse.png"),
    )
    for draw, file_name in figure_jobs:
        draw(summary, figures_dir / file_name, ablations)

    print(f"Figures saved to {figures_dir}")


if __name__ == "__main__":
    # Command-line entry point: collect the three paths and hand them to main().
    cli = argparse.ArgumentParser(description="Generate plots for Simulation C")
    option_specs = (
        ("--summary", {"required": True, "help": "Path to simC_summary.csv"}),
        ("--output_dir", {"required": True, "help": "Directory to save figures"}),
        ("--ablation", {"default": None, "help": "Path to simC_ablation.csv (optional)"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()
    main(opts.summary, opts.output_dir, opts.ablation)